{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 【迁移学习】往一个已经保存好的 模型添加新的变量\n",
    "在迁移学习中，通常我们已经训练好一个模型，现在需要修改模型的部分结构，用于我们的新任务。\n",
    "\n",
    "比如：\n",
    "\n",
    "在一个图片分类任务中，我们使用别人训练好的网络来提取特征，但是我们的分类数目和原模型不同，这样我们只能取到 fc 层，后面的分类层需要重新写。这样我们就需要添加新的变量。那么这些新加入的变量必须得初始化才能使用。可是我们又不能使用 'tf.global_variables_initializer()' 来初始化，否则原本训练好的模型就没用了。\n",
    "\n",
    "关于怎么样单独初始化新增的变量，可以参考下面两个链接：\n",
    "- [1]https://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables\n",
    "- [2]https://stackoverflow.com/questions/35013080/tensorflow-how-to-get-all-variables-from-rnn-cell-basiclstm-rnn-cell-multirnn\n",
    "\n",
    "简单的例子可以直接看下面我的实现方式。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 初始模型定义与保存\n",
    "首先定义一个模型，里边有 v1, v2 两个变量，我们把这个模型保存起来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in file: ../ckpt/test-model.ckpt-1\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "# Create some variables.\n",
    "v1 = tf.Variable([1.0, 2.3], name=\"v1\")\n",
    "v2 = tf.Variable(55.5, name=\"v2\")\n",
    "\n",
    "\n",
    "init_op = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()\n",
    "ckpt_path = '../ckpt/test-model.ckpt'\n",
    "sess.run(init_op)\n",
    "save_path = saver.save(sess, ckpt_path, global_step=1)\n",
    "print(\"Model saved in file: %s\" % save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**restart(重启kernel然后执行下面cell的代码)**\n",
    "### 导入已经保存好的模型，并添加新的变量\n",
    "现在把之前保存好的模型，我只需要其中的 v1 变量。同时我还要添加新的变量 v3。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ../ckpt/test-model.ckpt-1\n",
      "[1.  2.3]\n",
      "[1.  2.3]\n",
      "666\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'../ckpt/test-model.ckpt-2'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "# Create some variables.\n",
    "v1 = tf.Variable([11.0, 16.3], name=\"v1\")\n",
    "\n",
    "# ** 导入训练好的模型\n",
    "saver = tf.train.Saver()\n",
    "ckpt_path = '../ckpt/test-model.ckpt'\n",
    "saver.restore(sess, ckpt_path + '-'+ str(1))\n",
    "print(sess.run(v1))\n",
    "\n",
    "# ** 定义新的变量并单独初始化新定义的变量\n",
    "v3 = tf.Variable(666, name='v3', dtype=tf.int32)\n",
    "init_new = tf.variables_initializer([v3])\n",
    "sess.run(init_new)\n",
    "# 。。。这里就可以进行 fine-tune 了\n",
    "print(sess.run(v1))\n",
    "print(sess.run(v3))\n",
    "\n",
    "# ** 保存新的模型。 \n",
    "#  注意！注意！注意！ 一定一定一定要重新定义 saver, 这样才能把 v3 添加到 checkpoint 中\n",
    "saver = tf.train.Saver()\n",
    "saver.save(sess, ckpt_path, global_step=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**restart(重启kernel然后执行下面cell的代码)**\n",
    "### 这样就完成了 fine-tune, 得到了新的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ../ckpt/test-model.ckpt-2\n",
      "Model restored.\n",
      "[1.  2.3]\n",
      "666\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "# Create some variables.\n",
    "v1 = tf.Variable([11.0, 16.3], name=\"v1\")\n",
    "v3 = tf.Variable(666, name='v3', dtype=tf.int32)\n",
    "# Add ops to save and restore all the variables.\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "# Later, launch the model, use the saver to restore variables from disk, and\n",
    "# do some work with the model.\n",
    "# Restore variables from disk.\n",
    "ckpt_path = '../ckpt/test-model.ckpt'\n",
    "saver.restore(sess, ckpt_path + '-'+ str(2))\n",
    "print(\"Model restored.\")\n",
    "\n",
    "print(sess.run(v1))\n",
    "print(sess.run(v3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
